support weight-update in disaggregated mode using sglang #1766

Open

PengchengShi00 wants to merge 4 commits into InternLM:rl_design from PengchengShi00:support-weight-update-sglang-dis

Conversation

@PengchengShi00

  1. Extract the weight-sync functions from TrainingWorker into a new UpdateWeighter class.
  2. Port the colocated (train/rollout sharing GPUs) weight-sync optimizations from Update weight persist buffer #1653 into the UpdateWeighter class.
  3. Fix a bug where, when running sglang in disaggregated mode, the GPUs used by rollout and train were not deployed separately.
  4. Update the configuration parameters of GSM8KJudgerConfig.
  5. Add weight synchronization for disaggregated mode when using sglang as the inference backend (see the sketch after this list):
    a. Create a gloo group among the training ranks, used for barriers during disaggregated weight sync.
    b. Create an NCCL process group so that training rank0 can broadcast the bucketed weights to the SGLang rollout ranks.
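
A minimal sketch of how those two groups could be created with torch.distributed (the rank lists and bucket shape are placeholders for illustration, not this PR's actual values):

```python
import torch
import torch.distributed as dist

# Hypothetical sketch; assumes init_process_group() already ran over a
# world that contains both the train and the rollout processes.
TRAIN_RANKS = [0, 1, 2, 3]      # placeholder: ranks running training
ROLLOUT_RANKS = [4, 5, 6, 7]    # placeholder: ranks running SGLang rollout

# (a) gloo group among training ranks, used only for CPU-side barriers
# during disaggregated weight sync.
train_gloo_group = dist.new_group(ranks=TRAIN_RANKS, backend="gloo")
if dist.get_rank() in TRAIN_RANKS:
    dist.barrier(group=train_gloo_group)

# (b) NCCL group spanning train rank0 plus the rollout ranks; rank0
# broadcasts each bucketed weight tensor to the rollout side.
bcast_ranks = [TRAIN_RANKS[0]] + ROLLOUT_RANKS
bcast_group = dist.new_group(ranks=bcast_ranks, backend="nccl")
if dist.get_rank() in bcast_ranks:
    bucket = torch.empty(1024, device="cuda")  # placeholder weight bucket
    dist.broadcast(bucket, src=TRAIN_RANKS[0], group=bcast_group)
```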

@CyCle1024 CyCle1024 self-requested a review May 8, 2026 08:39
Comment thread examples/v1/config/rl_disagg_single.py
Comment thread xtuner/v1/train/rl_trainer.py Outdated
@@ -1155,5 +1158,7 @@ async def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict):
        self.fake_update_weights()

    def fake_update_weights(self):
Collaborator

Rename the function: if this is a real update-weight operation, please remove fake.

Author

Updated in the latest commit: fake_update_weights has been renamed to update_weights.

Comment on lines +9 to +14
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
TEST_DIR = Path(__file__).resolve().parent
if str(TEST_DIR) not in sys.path:
    sys.path.insert(0, str(TEST_DIR))
Collaborator

This is unnecessary in the CI unit tests.

Comment thread xtuner/v1/rl/rollout/sglang.py Outdated

    def pause_generation(self):
-       return self._make_request("pause_generation")
+       return self._make_request("pause_generation", {"mode": "retract"})
Collaborator

Add a comment explaining the extra parameter.

Author

I have added a comment in the latest commit.

SGLang PauseGeneration supports three modes:

  • abort: drop both waiting and running requests
  • retract: keep waiting/running requests and generated tokens, release KV cache, and recompute KV on resume
  • in_place: keep waiting/running requests, generated tokens, and KV cache, and resume directly

I also changed the mode to abort. Before update_weights, send_abort_request has already been issued, so there should be no pending requests to preserve. In this case, abort is sufficient and makes the intended behavior clearer.
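
For reference, a hypothetical sketch of the resulting call order around a weight update; `rollout_client`, `update_weighter`, and `continue_generation` are stand-in names, not this PR's exact API:

```python
def sync_weights(rollout_client, update_weighter):
    # Hypothetical ordering sketch of the flow described above;
    # method names are assumptions, not this PR's exact interfaces.
    rollout_client.send_abort_request()    # drop in-flight requests first
    rollout_client.pause_generation()      # nothing left to preserve -> "abort"
    update_weighter.update_weights()       # broadcast new weights to rollout
    rollout_client.continue_generation()   # assumed resume call
```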

Comment on lines +309 to +315
ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers])
train_controller.update_weights()
first_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers])

ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers])
train_controller.update_weights()
second_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers])
Collaborator

This hash logic verifies that the update-weight operation is deterministic. But the hash result should also be used to compare the training-side sent state_dict against the rollout-side received state_dict.

Author

In the latest commit, I added per-bucket hash checks for both the training-side sent state_dict and the rollout-side received state_dict.

To support this comparison, I also needed a small SGLang-side change so the rollout side can return the received bucket hash. I have rebuilt the docker image with that SGLang patch applied.

In the latest commit, the unit tests now verify:

  1. The rollout output remains unchanged for the same input before and after the weight update.
  2. For each bucket, the training-side sent state_dict hash matches the rollout-side received state_dict hash.
  3. Across two consecutive weight updates, the training-side sent bucket state_dict hashes remain identical.
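
A minimal sketch of what the per-bucket hash could look like (`weight_buckets` and the rollout-side hash query are assumptions for illustration, not the PR's exact interfaces):

```python
import hashlib

import torch


def bucket_sha256(flat_bucket: torch.Tensor) -> str:
    # Hash the raw bytes of the flattened bucket; viewing as uint8
    # keeps the hash dtype-agnostic (works for bf16 buckets too).
    raw = flat_bucket.detach().contiguous().cpu().view(torch.uint8)
    return hashlib.sha256(raw.numpy().tobytes()).hexdigest()


# Training side: hash each bucket right before broadcasting it.
weight_buckets = [torch.randn(1024), torch.randn(2048)]  # placeholder buckets
sent_hashes = [bucket_sha256(b) for b in weight_buckets]

# Rollout side (via the patched SGLang endpoint) would return the hashes
# of the received buckets; the unit test asserts per-bucket equality:
# assert sent_hashes == rollout_client.get_received_bucket_hashes()  # assumed API
```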

Comment thread xtuner/v1/rl/trainer/update_weighter.py Outdated
self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False)
del state_dict, name_list, param_list

if self.rollout_cfg_info["backend"] == "pytorch" and final_update:
Collaborator

Suggested change
-       if self.rollout_cfg_info["backend"] == "pytorch" and final_update:
+       if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update:

)

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_update_weight_and_generate(self):
Collaborator

This is not a disaggregated case, and it duplicates the case in tests/rl/test_update_weight.py; it should be removed or force-skipped for now.

from xtuner.v1.utils import ray_method

TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}]
MODEL_PATH = os.environ.get("MODEL_PATH") or os.environ.get("QWEN3_VL_DENSE_PATH")
Collaborator

We should remove os.environ.get("MODEL_PATH") here and just use the original CI env var.

)

# training config
model_cfg = get_model_config_from_hf(Path(MODEL_PATH))
Collaborator

Maybe use a specific model config here when using QWEN3_VL_DENSE_PATH; see tests/rl/test_update_weight.py.

def setUpClass(cls) -> None:
    if MODEL_PATH is None:
        raise unittest.SkipTest("MODEL_PATH is not set")
    os.environ["XTUNER_USE_FA3"] = "1"
Collaborator

NCCL_CUMEM_ENABLE=0 is required in my test environment; we could add it here, or at the actor creation stage via the runtime_env of ray.remote.

Author

NCCL_CUMEM_ENABLE=0 is also required in my test environment. The latest commit has already added this environment variable in the unit tests.
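
If we also want it at actor creation, here is a sketch using Ray's standard runtime_env env_vars mechanism, as the reviewer suggests (the actor class below is a placeholder, not the real TrainingWorker):

```python
import ray


@ray.remote(num_gpus=1)
class TrainingWorker:  # placeholder stand-in for the real actor class
    def ping(self) -> str:
        return "ok"


# Every process backing this actor inherits the env var, so NCCL cuMem
# allocations are disabled without touching the shell environment.
worker = TrainingWorker.options(
    runtime_env={"env_vars": {"NCCL_CUMEM_ENABLE": "0"}}
).remote()
ray.get(worker.ping.remote())
```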

@CyCle1024
Collaborator

BTW, fix CI lint.

